/**********************************************************************
  Copyright(c) 2022-2023 Arm Corporation All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions
  are met:
    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in
      the documentation and/or other materials provided with the
      distribution.
    * Neither the name of Arm Corporation nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**********************************************************************/
#ifndef MB_MGR_SNOW3G_SUBMIT_FLUSH_AARCH64_H
#define MB_MGR_SNOW3G_SUBMIT_FLUSH_AARCH64_H

#include "include/ipsec_ooo_mgr.h"
#include "snow3g_internal.h"
#include "snow3g.h"
#include <arm_neon.h>
#include <assert.h>
#include <string.h>
#ifdef SAFE_PARAM
#include "error.h"
#endif

/*
 * state->unused_lanes is a packed free list of lane indexes: each entry is
 * UNUSED_LANE_MASK_BITS wide and the next lane to hand out sits in the lowest
 * bits, selected with UNUSED_LANE_MASK.
 */
#define UNUSED_LANE_MASK_BITS       4
#define UNUSED_LANE_MASK            0xF

/*
 * state->init_done carries one bit per lane; INIT_DONE_MASK covers every
 * valid lane. It is only defined for 4 or 8 lanes; any other lane count
 * leaves it undefined and fails at the first use.
 */
#if SNOW3G_MB_MAX_LANES_SIMD == 4
#define INIT_DONE_MASK  0x0F
#elif SNOW3G_MB_MAX_LANES_SIMD == 8
#define INIT_DONE_MASK  0xFF
#endif
#define INIT_ALL_DONE   INIT_DONE_MASK

/*
 * Lane predicates. `state` is a pointer to the lane manager and `i` a lane
 * index. Both parameters are parenthesized so that arbitrary expressions
 * (e.g. `&mgr`, `base + 1`, `a | b`) expand correctly.
 * JOB_IS_COMPLETED evaluates each argument twice, so avoid side effects.
 */

/* Lane holds a job and has no bytes left to process. */
#define JOB_IS_COMPLETED(state, i)  \
        ((((state)->job_in_lane[(i)]) != NULL) && ((state)->args.byte_length[(i)] == 0))
/* Lane's INITIALIZED flag is 0. */
#define JOB_NOT_INITIALIZED(state, i) \
        (((state)->args.INITIALIZED[(i)] == 0))
/* Lane's INITIALIZED flag is 1. */
#define JOB_INITIALIZED(state, i) \
        (((state)->args.INITIALIZED[(i)] == 1))
/* Lane holds no job. */
#define JOB_IS_NULL(state, i) \
        ((state)->job_in_lane[(i)] == NULL)


/*
 * Out-of-order manager entry points defined later in this header.
 * Each returns a finished job, or NULL when no job is ready to hand back.
 * The UEA2 pair takes the top-level IMB_MGR (the lane manager is read from
 * state->snow3g_uea2_ooo); the UIA2 pair takes the lane manager directly.
 */
IMB_JOB *SUBMIT_JOB_SNOW3G_UEA2(IMB_MGR *state, IMB_JOB *job);
IMB_JOB *FLUSH_JOB_SNOW3G_UEA2(IMB_MGR *state);
IMB_JOB *SUBMIT_JOB_SNOW3G_UIA2(MB_MGR_SNOW3G_OOO *state, IMB_JOB *job);
IMB_JOB *FLUSH_JOB_SNOW3G_UIA2(MB_MGR_SNOW3G_OOO *state);

/*
 * The routines below are only declared here; their definitions are not in
 * this header. The notes describe how this file calls them.
 */

/*
 * Called with a snow3gKeyState1_t context to process lengthInBytes bytes from
 * pBufferIn into pBufferOut.
 * NOTE(review): presumably continues the keystream of an already-initialized
 * single-lane state -- confirm against the definition.
 */
void SNOW3G_F8_1_BUFFER_STREAM_JOB(void *pCtx,
                                   const void *pBufferIn,
                                   void *pBufferOut,
                                   const uint32_t lengthInBytes);

/*
 * Called with a snow3gKeyStateMulti_t context plus one key schedule pointer
 * and one IV pointer per lane (this file passes state->args.keys/iv).
 */
void SNOW3G_F8_MULTI_BUFFER_INITIALIZE_JOB(void *pCtx,
                                           const snow3g_key_schedule_t **pKeySched,
                                           const uint8_t **pIV);

/*
 * Called with the multi-lane context and per-lane input/output pointer
 * arrays; lengthInBytes is the same for every lane.
 * NOTE(review): callers reuse the pointer arrays afterwards for the tail
 * bytes, so this presumably advances them -- confirm.
 */
void SNOW3G_F8_MULTI_BUFFER_STREAM_JOB(void *pCtx,
                                       const uint8_t **pBufferIn,
                                       uint8_t **pBufferOut,
                                       const uint32_t lengthInBytes);

/*
 * Called with the multi-lane context and state->ks; this file then reads
 * five consecutive words per lane starting at ks[lane * 5].
 */
void SNOW3G_F9_MULTI_BUFFER_KEYSTREAM_JOB(void *pCtx,
                                          uint32_t *ks);

/*
 * Called with a lane's five keystream words, the message pointer, the
 * message length in bits and the tag output pointer.
 */
void SNOW3G_F9_1_BUFFER_DIGEST_JOB(const uint32_t z[5],
                                   const void *pBufferIn,
                                   const uint64_t lengthInBits,
                                   void *pDigest);

/*
 * Park a byte-aligned UEA2 (cipher) job in the next free lane.
 *
 * The lane index is popped from the low bits of state->unused_lanes. The
 * lane then records the job's key, IV, source/destination pointers (offset
 * by the cipher start, in bytes) and byte count, and is flagged as not yet
 * initialized.
 */
static void snow3g_mb_mgr_insert_uea2_job(MB_MGR_SNOW3G_OOO *state, IMB_JOB *job)
{
    const uint64_t lane = state->unused_lanes & UNUSED_LANE_MASK;

    assert(lane < SNOW3G_MB_MAX_LANES_SIMD);

    /* Pop the lane off the free list and account for it. */
    state->unused_lanes >>= UNUSED_LANE_MASK_BITS;
    state->num_lanes_inuse++;

    /* Bookkeeping for the lane. */
    state->job_in_lane[lane] = job;
    state->args.INITIALIZED[lane] = 0;
    state->lens[lane] = job->msg_len_to_cipher_in_bits / 8;
    state->args.byte_length[lane] = job->msg_len_to_cipher_in_bits / 8;

    /* Key material and data pointers. */
    state->args.keys[lane] = job->enc_keys;
    state->args.iv[lane] = job->iv;
    state->args.in[lane] = job->src + (job->cipher_start_src_offset_in_bits / 8);
    state->args.out[lane] = job->dst + (job->cipher_start_src_offset_in_bits / 8);
}

/*
 * Park a UIA2 (authentication) job in the next free lane.
 *
 * The lane index is popped from the low bits of state->unused_lanes. The
 * lane records the job's key, IV, message pointer, tag output pointer and
 * length in bits, and its bit in state->init_done is cleared so the lane is
 * treated as having no keystream yet.
 */
static void snow3g_mb_mgr_insert_uia2_job(MB_MGR_SNOW3G_OOO *state, IMB_JOB *job)
{
    const uint64_t lane = state->unused_lanes & UNUSED_LANE_MASK;

    assert(lane < SNOW3G_MB_MAX_LANES_SIMD);

    /* Pop the lane off the free list and account for it. */
    state->unused_lanes >>= UNUSED_LANE_MASK_BITS;
    state->num_lanes_inuse++;

    /* Bookkeeping for the lane. */
    state->job_in_lane[lane] = job;
    state->args.INITIALIZED[lane] = 0;
    state->init_done &= (~(1 << lane) & INIT_DONE_MASK);
    state->lens[lane] = job->msg_len_to_hash_in_bits;

    /* Key material and data pointers. */
    state->args.keys[lane] = job->u.SNOW3G_UIA2._key;
    state->args.iv[lane] = job->u.SNOW3G_UIA2._iv;
    state->args.in[lane] = job->src + job->hash_start_src_offset_in_bytes;
    state->args.out[lane] = job->auth_tag_output;
}

/*
 * Release the first lane that holds a completed UEA2 job.
 *
 * A lane is complete when it holds a job and its remaining byte count is 0.
 * The job is marked IMB_STATUS_COMPLETED_CIPHER, the lane is pushed back onto
 * the unused_lanes free list and its bookkeeping is reset. With SAFE_DATA the
 * lane's column of the interleaved key state (16 + 3 rows) is zeroed.
 *
 * Returns the completed job, or NULL when no lane is complete.
 */
static IMB_JOB *snow3g_mb_mgr_free_uea2_job(MB_MGR_SNOW3G_OOO *state)
{
    IMB_JOB *ret = NULL;

    /*
     * Valid lane indexes are [0, SNOW3G_MB_MAX_LANES_SIMD). The bound must be
     * exclusive, otherwise a non-existent lane one past the end is probed and
     * could be pushed onto the free list.
     */
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        if (!JOB_IS_COMPLETED(state, i))
            continue;

        ret = state->job_in_lane[i];
        ret->status |= IMB_STATUS_COMPLETED_CIPHER;
        state->job_in_lane[i] = NULL;

        /* Push the lane index back onto the packed free list. */
        state->unused_lanes = state->unused_lanes << UNUSED_LANE_MASK_BITS;
        state->unused_lanes |= i;
        state->num_lanes_inuse--;

        state->lens[i] = 0;
        state->args.INITIALIZED[i] = 0;
#ifdef SAFE_DATA
        /* Wipe this lane's word in every row of the interleaved state. */
        uint32_t *key_state = (uint32_t *)&(state->args.LFSR_0[0]);
        for (int k = 0; k < (16 + 3); k++) {
            key_state[k * SNOW3G_MB_MAX_LANES_SIMD + i] = 0;
        }
#endif
        break;
    }

    return ret;
}

/*
 * Release lane `i`, which must currently hold a UIA2 job.
 *
 * The job is marked IMB_STATUS_COMPLETED_AUTH, the lane index is pushed back
 * onto the unused_lanes free list, and the lane's length, INITIALIZED flag
 * and init_done bit are cleared. With SAFE_DATA the lane's column of the
 * interleaved key state and its five keystream words are zeroed.
 *
 * Returns the job that occupied the lane.
 */
static IMB_JOB *snow3g_mb_mgr_free_uia2_job(MB_MGR_SNOW3G_OOO *state, int i)
{
    assert(!JOB_IS_NULL(state, i));

    IMB_JOB *done = state->job_in_lane[i];

    done->status |= IMB_STATUS_COMPLETED_AUTH;
    state->job_in_lane[i] = NULL;

    /* Return the lane to the packed free list. */
    state->unused_lanes <<= UNUSED_LANE_MASK_BITS;
    state->unused_lanes |= i;
    state->num_lanes_inuse--;

    /* Reset the per-lane bookkeeping. */
    state->lens[i] = 0;
    state->args.INITIALIZED[i] = 0;
    state->init_done &= (~(1 << i) & INIT_DONE_MASK);

#ifdef SAFE_DATA
    /* Wipe this lane's word in each of the 16 + 3 interleaved state rows. */
    uint32_t *lane_words = (uint32_t *)&(state->args.LFSR_0[0]);
    for (int row = 0; row < (16 + 3); row++)
        lane_words[row * SNOW3G_MB_MAX_LANES_SIMD + i] = 0;

    /* Wipe the lane's five keystream words. */
    for (int w = 0; w < 5; w++)
        state->ks[i * 5 + w] = 0;
#endif

    return done;
}

/*
 * Extract one lane of the interleaved multi-lane state into a single-lane
 * context.
 *
 * The multi-lane state is read as rows of SNOW3G_MB_MAX_LANES_SIMD words.
 * Rows 0..15 form a ring whose logical first row is state->iLFSR_X, so they
 * are un-rotated while copying. Rows 16..18 are copied in place.
 * NOTE(review): rows 16..18 land just after ctx->LFSR_S[15]; this presumably
 * relies on the FSM registers following LFSR_S in snow3gKeyState1_t --
 * confirm against the struct definitions.
 */
__forceinline
void cpy_state_to_ctx1(snow3gKeyStateMulti_t* state, snow3gKeyState1_t* ctx, const int num_lane) {
    uint32_t *multi = (uint32_t *)&(state->LFSR_X[0]);
    uint32_t *single = (uint32_t *)&(ctx->LFSR_S[0]);
    const uint32_t head = state->iLFSR_X;
    int w = 0;

    /* Ring rows: logical word w lives at physical row (w + head) mod 16. */
    while (w < 16) {
        single[w] = multi[((w + head) % 16) * SNOW3G_MB_MAX_LANES_SIMD + num_lane];
        w++;
    }

    /* Remaining three rows are not rotated. */
    while (w < 19) {
        single[w] = multi[w * SNOW3G_MB_MAX_LANES_SIMD + num_lane];
        w++;
    }
}

/*
 * Merge freshly initialized lanes from `new` into the manager's persistent
 * multi-lane state (state->args.LFSR_0).
 *
 * Only lanes whose INITIALIZED flag is 0 are copied; lanes that are already
 * running keep their state. Each copied lane is then flagged INITIALIZED = 1.
 *
 * Both states store rows 0..15 as a ring with its own head (iLFSR_X), so
 * logical word j is read at (j + src head) mod 16 and written at
 * (j + dst head) mod 16. Rows 16..18 are copied without rotation.
 */
__forceinline
void cpy_newly_intialized_ctx_to_state(snow3gKeyStateMulti_t* new, MB_MGR_SNOW3G_OOO* state) {
    snow3gKeyStateMulti_t* ctx = (snow3gKeyStateMulti_t *)&(state->args.LFSR_0[0]);
    uint32_t* dst = (uint32_t *)&(ctx->LFSR_X[0]);
    uint32_t* src = (uint32_t *)&(new->LFSR_X[0]);
    uint32_t dst_iLFSR = ctx->iLFSR_X;
    uint32_t src_iLFSR = new->iLFSR_X;
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        if (JOB_NOT_INITIALIZED(state, i)) {
            /*
             * Exactly 16 ring words. Iterating to 16 + 3 here would wrap
             * modulo 16 and re-copy words 0..2 for no effect; the three
             * extra rows are handled by the loop below.
             */
            for (int j = 0; j < 16; j++) {
                dst[((j + dst_iLFSR) % 16) * SNOW3G_MB_MAX_LANES_SIMD + i] =
                src[((j + src_iLFSR) % 16) * SNOW3G_MB_MAX_LANES_SIMD + i];
            }
            /* Rows 16..18 are not part of the ring: copy in place. */
            for (int j = 16; j < 19; j++) {
                dst[j * SNOW3G_MB_MAX_LANES_SIMD + i] = src[j * SNOW3G_MB_MAX_LANES_SIMD + i];
            }
            state->args.INITIALIZED[i] = 1;
        }
    }
}

/*
 * Submit a SNOW3G UEA2 (cipher) job to the out-of-order manager.
 *
 * - With SAFE_PARAM, NULL key/IV/src/dst or an out-of-range bit length sets
 *   the library error code and returns NULL without queuing the job.
 * - A job whose bit length or bit offset is not a multiple of 8 is processed
 *   immediately with IMB_SNOW3G_F8_1_BUFFER_BIT and returned as completed.
 * - Otherwise the job is placed in a free lane. If some lane already holds a
 *   finished job it is returned; if not all lanes are occupied yet, NULL is
 *   returned; once all lanes are occupied the lanes are advanced together
 *   and one finished job (if any) is returned.
 *
 * Returns a completed job (not necessarily `job`) or NULL.
 */
IMB_JOB *SUBMIT_JOB_SNOW3G_UEA2(IMB_MGR *state,
                                       IMB_JOB *job)
{
#ifdef SAFE_PARAM
        /* reset error status */
        if (imb_errno != 0)
                imb_set_errno(NULL, 0);

        if (job->enc_keys == NULL) {
                imb_set_errno(NULL, IMB_ERR_NULL_EXP_KEY);
                return NULL;
        }
        if (job->iv == NULL) {
                imb_set_errno(NULL, IMB_ERR_NULL_IV);
                return NULL;
        }

        if (job->src == NULL) {
                imb_set_errno(NULL, IMB_ERR_NULL_SRC);
                return NULL;
        }
        if (job->dst == NULL) {
                imb_set_errno(NULL, IMB_ERR_NULL_DST);
                return NULL;
        }
        if ((job->msg_len_to_cipher_in_bits == 0) ||
            (job->msg_len_to_cipher_in_bits > SNOW3G_MAX_BITLEN)) {
                imb_set_errno(NULL, IMB_ERR_CIPH_LEN);
                return NULL;
        }
#endif

    MB_MGR_SNOW3G_OOO *snow3g_state = state->snow3g_uea2_ooo;
    uint32_t msg_bitlen = job->msg_len_to_cipher_in_bits;
    uint32_t msg_bitoff = job->cipher_start_src_offset_in_bits;

    /* Use bit length API if
     * - msg length is not a multiple of bytes
     * - bit offset is not a multiple of bytes
     */
    if ((msg_bitlen & 0x07) || (msg_bitoff & 0x07)) {
        IMB_SNOW3G_F8_1_BUFFER_BIT(state, job->enc_keys, job->iv, job->src,
                                   job->dst, msg_bitlen, msg_bitoff);
        /* Handled synchronously: this job never occupies a lane. */
        job->status |= IMB_STATUS_COMPLETED_CIPHER;
        return job;
    }

    IMB_JOB *ret = NULL;

    snow3g_mb_mgr_insert_uea2_job(snow3g_state, job);

    /* Hand back a job left finished by an earlier call, if there is one. */
    ret = snow3g_mb_mgr_free_uea2_job(snow3g_state);
    if (ret != NULL)
        return ret;

    /* Only run the multi-lane routines once every lane is occupied. */
    if(snow3g_state->num_lanes_inuse < SNOW3G_MB_MAX_LANES_SIMD)
        return NULL;

    uint32_t min_word_len = UINT32_MAX;
    /* Persistent interleaved key state kept inside the manager. */
    snow3gKeyStateMulti_t *pCtx = (snow3gKeyStateMulti_t *)&(snow3g_state->args.LFSR_0[0]);
    snow3gKeyStateMulti_t tmp_ctx;

    /*
     * Initialize every lane into a scratch context, then merge only the
     * lanes that are not yet initialized so running lanes keep their state.
     */
    SNOW3G_F8_MULTI_BUFFER_INITIALIZE_JOB(&tmp_ctx,
                                          (const snow3g_key_schedule_t **)snow3g_state->args.keys,
                                          (const uint8_t**)snow3g_state->args.iv);
    cpy_newly_intialized_ctx_to_state(&tmp_ctx, snow3g_state);

    /* Smallest number of whole SNOW3G_4_BYTES units remaining in any lane. */
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        min_word_len = (min_word_len < snow3g_state->args.byte_length[i] / SNOW3G_4_BYTES) ?
                       min_word_len : snow3g_state->args.byte_length[i] / SNOW3G_4_BYTES;
    }

    /*
     * Advance all lanes by that common length.
     * NOTE(review): args.in/out are reused for the tail below, so this call
     * presumably advances the per-lane pointers -- confirm.
     */
    SNOW3G_F8_MULTI_BUFFER_STREAM_JOB(pCtx,
                                      (const uint8_t **)snow3g_state->args.in,
                                      (uint8_t **)snow3g_state->args.out,
                                      min_word_len * SNOW3G_4_BYTES);

    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        snow3g_state->args.byte_length[i] -= min_word_len * SNOW3G_4_BYTES;
    }

    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        // If less than one word is left, finish the job here using a
        // single-lane copy of this lane's state.
        if (snow3g_state->args.byte_length[i] < SNOW3G_4_BYTES &&
            snow3g_state->args.byte_length[i] != 0) {
            snow3gKeyState1_t ctx_1;
            cpy_state_to_ctx1(pCtx, &ctx_1, i);
            SNOW3G_F8_1_BUFFER_STREAM_JOB(&ctx_1, snow3g_state->args.in[i],
                                          snow3g_state->args.out[i],
                                          snow3g_state->args.byte_length[i]);
            snow3g_state->args.byte_length[i] = 0;
        }
    }

    /* Returns the first lane whose remaining length reached zero, if any. */
    ret = snow3g_mb_mgr_free_uea2_job(snow3g_state);

#ifdef SAFE_DATA
    // The lane's key state has been cleared in snow3g_mb_mgr_free_uea2_job.
#endif

    return ret;
}

/*
 * Flush one SNOW3G UEA2 job from the out-of-order manager.
 *
 * A lane that already holds a finished job is released first. Otherwise the
 * first occupied lane with a non-zero length is completed on its own:
 *  - if the lane was never initialized, the whole job is run through the
 *    single-buffer API with the lane's key and IV;
 *  - otherwise the lane's state is extracted from the multi-lane state and
 *    the remaining bytes are processed with the single-lane stream routine.
 * The lane is then returned to the free list and, with SAFE_DATA, its column
 * of the interleaved key state is zeroed.
 *
 * Returns a completed job, or NULL when no lane holds a job.
 */
IMB_JOB *FLUSH_JOB_SNOW3G_UEA2(IMB_MGR *state)
{
    MB_MGR_SNOW3G_OOO *snow3g_state = state->snow3g_uea2_ooo;
    IMB_JOB *ret = snow3g_mb_mgr_free_uea2_job(snow3g_state);

    if (ret != NULL) {
        return ret;
    }

    /*
     * Valid lane indexes are [0, SNOW3G_MB_MAX_LANES_SIMD); the bound must be
     * exclusive so that no lane past the last one is ever examined.
     */
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        if (snow3g_state->job_in_lane[i] == NULL || snow3g_state->lens[i] == 0)
            continue;

        ret = snow3g_state->job_in_lane[i];

        if (JOB_NOT_INITIALIZED(snow3g_state, i)) {
            /* Lane never entered the multi-lane path: run the whole job. */
            IMB_SNOW3G_F8_1_BUFFER(state, snow3g_state->args.keys[i],
                                   snow3g_state->args.iv[i],
                                   snow3g_state->args.in[i],
                                   snow3g_state->args.out[i],
                                   snow3g_state->args.byte_length[i]);
        } else {
            /*
             * Lane is mid-stream: continue from its current state. The local
             * is deliberately not called `state`, to avoid shadowing the
             * IMB_MGR parameter used by the branch above.
             */
            snow3gKeyState1_t ctx;
            snow3gKeyStateMulti_t *multi_ctx =
                (snow3gKeyStateMulti_t *)&(snow3g_state->args.LFSR_0[0]);

            cpy_state_to_ctx1(multi_ctx, &ctx, i);
            SNOW3G_F8_1_BUFFER_STREAM_JOB(&ctx, snow3g_state->args.in[i],
                                          snow3g_state->args.out[i],
                                          snow3g_state->args.byte_length[i]);
        }

        ret->status |= IMB_STATUS_COMPLETED_CIPHER;

        /* Release the lane and push it back onto the packed free list. */
        snow3g_state->lens[i] = 0;
        snow3g_state->job_in_lane[i] = NULL;
        snow3g_state->unused_lanes = snow3g_state->unused_lanes << UNUSED_LANE_MASK_BITS;
        snow3g_state->unused_lanes |= i;
        snow3g_state->num_lanes_inuse--;
        snow3g_state->args.byte_length[i] = 0;
        snow3g_state->args.INITIALIZED[i] = 0;
#ifdef SAFE_DATA
        /* Wipe this lane's word in every row of the interleaved state. */
        uint32_t *key_state = (uint32_t *)&(snow3g_state->args.LFSR_0[0]);
        for (int k = 0; k < (16 + 3); k++) {
            key_state[k * SNOW3G_MB_MAX_LANES_SIMD + i] = 0;
        }
#endif
        return ret;
    }
    return NULL;
}

/*
 * Submit a SNOW3G UIA2 (authentication) job to the out-of-order manager.
 *
 * With SAFE_PARAM, a NULL key/IV/src/tag pointer or an out-of-range bit
 * length sets the library error code and returns NULL without queuing.
 *
 * The job is placed in a free lane. Nothing is computed until every lane is
 * occupied. At that point, if no lane has keystream yet (init_done == 0),
 * all lanes are initialized and their keystream words generated in one
 * multi-lane pass. The first lane with keystream is then digested and
 * released.
 *
 * Returns a completed job (not necessarily `job`) or NULL.
 */
IMB_JOB *SUBMIT_JOB_SNOW3G_UIA2(MB_MGR_SNOW3G_OOO *state,
                                IMB_JOB *job)
{
#ifdef SAFE_PARAM
    /* reset error status */
    if (imb_errno != 0)
        imb_set_errno(NULL, 0);

    if (job->u.SNOW3G_UIA2._key == NULL) {
        imb_set_errno(NULL, IMB_ERR_NULL_EXP_KEY);
        return NULL;
    }
    if (job->u.SNOW3G_UIA2._iv == NULL) {
        imb_set_errno(NULL, IMB_ERR_NULL_IV);
        return NULL;
    }
    if (job->src == NULL) {
        imb_set_errno(NULL, IMB_ERR_NULL_SRC);
        return NULL;
    }
    if (job->auth_tag_output == NULL) {
        imb_set_errno(NULL, IMB_ERR_NULL_AUTH);
        return NULL;
    }
    if ((job->msg_len_to_hash_in_bits == 0) ||
        (job->msg_len_to_hash_in_bits > SNOW3G_MAX_BITLEN)) {
        imb_set_errno(NULL, IMB_ERR_AUTH_LEN);
        return NULL;
    }
#endif

    snow3g_mb_mgr_insert_uia2_job(state, job);

    /* Wait until every lane is occupied before doing any work. */
    if (state->num_lanes_inuse < SNOW3G_MB_MAX_LANES_SIMD)
        return NULL;

    if (state->init_done == 0) {
        /* No lane has keystream yet: produce it for all lanes at once. */
        snow3gKeyStateMulti_t fresh_ctx;

        SNOW3G_F8_MULTI_BUFFER_INITIALIZE_JOB(&fresh_ctx,
                                              (const snow3g_key_schedule_t **)state->args.keys,
                                              (const uint8_t**)state->args.iv);
        SNOW3G_F9_MULTI_BUFFER_KEYSTREAM_JOB(&fresh_ctx, state->ks);
        state->init_done = INIT_ALL_DONE;
    }

    /* Digest and release the first lane whose keystream is available. */
    for (int lane = 0; lane < SNOW3G_MB_MAX_LANES_SIMD; lane++) {
        if (!(state->init_done & (1 << lane)))
            continue;

        SNOW3G_F9_1_BUFFER_DIGEST_JOB(&state->ks[lane*5], state->args.in[lane],
                                      state->lens[lane], state->args.out[lane]);
        return snow3g_mb_mgr_free_uia2_job(state, lane);
    }

    return NULL;
}

/*
 * Flush one SNOW3G UIA2 job from the out-of-order manager.
 *
 * - Returns NULL when no lane is in use.
 * - If some lane already has keystream (its init_done bit is set), that lane
 *   is digested and released.
 * - Otherwise every occupied lane is marked in init_done, empty lanes borrow
 *   the key/IV pointers of an occupied lane so the multi-lane routines get
 *   valid pointers for every lane, keystream is generated for all lanes, and
 *   the last occupied lane is digested and released. The other occupied
 *   lanes keep their init_done bit for later calls.
 *
 * Returns a completed job, or NULL when there is nothing to flush.
 */
IMB_JOB *FLUSH_JOB_SNOW3G_UIA2(MB_MGR_SNOW3G_OOO *state)
{
    IMB_JOB *ret = NULL;
    MB_MGR_SNOW3G_OOO *snow3g_state = state;

    if (snow3g_state->num_lanes_inuse == 0) {
        // empty
        return NULL;
    }
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        if (snow3g_state->init_done & (1<<i)) {
            // pick an initialized lane
            SNOW3G_F9_1_BUFFER_DIGEST_JOB(&snow3g_state->ks[i*5], snow3g_state->args.in[i],
                                          snow3g_state->lens[i], snow3g_state->args.out[i]);
            ret = snow3g_mb_mgr_free_uia2_job(snow3g_state, i);
            return ret;
        }
    }

    /* -1 means "no occupied lane found"; never index with it. */
    int lane_idx = -1;
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        if (!JOB_IS_NULL(snow3g_state, i)) {
            snow3g_state->init_done |= (1<<i);
            lane_idx = i;
        }
    }
    if (lane_idx < 0) {
        /*
         * num_lanes_inuse said there was work but no lane holds a job.
         * Bail out rather than index the lane arrays with an invalid lane.
         */
        return NULL;
    }
    for (int i = 0; i < SNOW3G_MB_MAX_LANES_SIMD; i++) {
        // copy keys and ivs to empty lane
        if (JOB_IS_NULL(snow3g_state, i)) {
            snow3g_state->args.keys[i] = snow3g_state->args.keys[lane_idx];
            snow3g_state->args.iv[i] = snow3g_state->args.iv[lane_idx];
        }
    }

    snow3gKeyStateMulti_t ctx;
    SNOW3G_F8_MULTI_BUFFER_INITIALIZE_JOB(&ctx,
                                          (const snow3g_key_schedule_t **)snow3g_state->args.keys,
                                          (const uint8_t **)snow3g_state->args.iv);
    SNOW3G_F9_MULTI_BUFFER_KEYSTREAM_JOB(&ctx,
                                         snow3g_state->ks);
    // digest and release the last occupied lane
    SNOW3G_F9_1_BUFFER_DIGEST_JOB(&snow3g_state->ks[lane_idx*5], snow3g_state->args.in[lane_idx],
                                  snow3g_state->lens[lane_idx], snow3g_state->args.out[lane_idx]);
    ret = snow3g_mb_mgr_free_uia2_job(snow3g_state, lane_idx);
    return ret;
}

#endif // MB_MGR_SNOW3G_SUBMIT_FLUSH_AARCH64_H
